# -*- coding: utf-8 -*-
"""CircularPlots.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1i7C_xjcYUkC3hFf95XH65uD0002KXZ2Y
"""

#@title Importing libraries and downloading data to colab
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import circmean

import pandas as pd
from google.colab import files

# Open Colab's upload dialog so the input spreadsheet can be provided.
# NOTE(review): the returned dict is not used afterwards; the data cell reads
# the file by name (`file_name`), which presumably relies on files.upload()
# writing uploads into the working directory — confirm if that read fails.
uploaded = files.upload()

#@title Pain ahead (do not open if not sure)
class CircPlotter:
    """Circular statistics and a polar dot plot for one sample of angles.

    On construction the sample is converted to radians and its circular mean
    direction, mean resultant length, approximate Rayleigh p-value and
    critical resultant length are computed.  ``computeCI`` adds a bootstrap
    confidence interval for the mean direction, and ``make_plot`` draws the
    sample, the mean vector, optional critical-r circles, the CI arc(s) and a
    text summary.

    Args:
        data: 1-D sequence of angles.
        grouped: if true, ``groups`` labels each angle; points are coloured
            per group and a legend is drawn.
        name: plot title.
        data_in_angles: true if ``data`` is in degrees (converted to radians);
            otherwise ``data`` is taken to be in radians.
        no_round: if false, angles are snapped to a grid before plotting so
            near-identical values stack (see ``prepare_sample``).
        n_angles: marker diameter in degrees of arc, also the radial stacking
            step.  May be changed before calling ``make_plot``.
        groups: group labels aligned with ``data``; required when ``grouped``.
        axial: treat the data as axial: statistics use doubled angles, and
            both antipodal means and CIs are drawn.

    Raises:
        ValueError: if ``grouped`` is true and ``groups`` is missing or its
            length differs from ``data``.
    """

    def __init__(self, data, grouped, name, data_in_angles, no_round=0, n_angles=3,
                 groups=None, axial=False):
        # A float ndarray keeps the later maths (and positional indexing in
        # make_plot) well defined for lists, pandas Series and integer input.
        data = np.asarray(data, dtype=float)
        if data_in_angles:
            data = data * np.pi / 180

        self.data = data
        self.name = name
        self.no_round = no_round
        self.n_angles = n_angles
        self.axial = axial

        self.grouped = grouped
        if grouped:
            if groups is None:
                raise ValueError('groups must be given when grouped is true')
            groups = np.asarray(groups)
            if groups.shape[0] != data.shape[0]:
                raise ValueError('groups has {} labels but data has {} angles'
                                 .format(groups.shape[0], data.shape[0]))
            self.groups = groups
        else:
            # Always define the attribute; make_plot only reads it when grouped.
            self.groups = None

        self.sample_size = data.shape[0]

        if axial:
            # Doubling maps opposite directions onto the same angle; the mean
            # of the doubled angles is halved back afterwards.
            mean, mean_l = compute_mean(data * 2)
            mean /= 2
        else:
            mean, mean_l = compute_mean(data)

        self.mean = mean          # mean direction, radians
        self.mean_l = mean_l      # mean resultant length
        self.rayleigh_p = rayleigh(mean_l, self.sample_size)
        self.critical_r = compute_critical(self.sample_size)

        # Filled in by computeCI().
        self.need_ci = False
        self.alpha = None
        self.ci = None

    def __repr__(self):
        return 'Mean direction = {}\nMean vector = {}\nRayleigh p = {}\n'.format(
            self.mean * 180/np.pi, self.mean_l, self.rayleigh_p)

    def computeCI(self, alpha, n_iters=10000):
        """Bootstrap a CI for the mean direction and enable it in plots.

        Args:
            alpha: significance level in percent (5 gives a 95% interval).
            n_iters: number of bootstrap resamples.
        """
        self.need_ci = True
        self.alpha = alpha
        if not self.axial:
            self.ci = np.array(bootstrapCI(self.data, alpha, n_iters))
        else:
            # CI of the doubled angles, halved back; the antipodal interval is
            # the same interval shifted by pi.  Shape (2, 2).
            doubled_ci = bootstrapCI(to_modulo_2pi(self.data * 2), alpha, n_iters)
            half = [bound / 2 for bound in doubled_ci]
            self.ci = np.array([half, [bound + np.pi for bound in half]])

    @staticmethod
    def _color_kwargs(colors_auto, colors, i):
        """Return the scatter colour kwarg for series ``i`` ({} = matplotlib's cycle)."""
        if colors_auto or colors is None:
            return {}
        if isinstance(colors, str):
            colors = [colors]  # a single colour name, not a sequence of letters
        if len(colors) == 0:
            return {}
        # Cycle when fewer colours than groups were supplied.
        return {'color': colors[i % len(colors)]}

    def _summary_text(self):
        """Build the N / mean / CI / Rayleigh-p annotation shown on the plot."""
        lines = ['N = {:.0f}'.format(self.sample_size),
                 'mean = {:.2f}'.format(to_positive(self.mean) * 180 / np.pi)]
        if self.need_ci:
            level = 100 - self.alpha
            if not self.axial:
                low, high = to_positive(self.ci) * 180 / np.pi
                lines.append('CI{} = ({:.2f}; {:.2f})'.format(level, low, high))
            else:
                (low, high), (low_ax, high_ax) = [to_positive(ci) * 180 / np.pi
                                                  for ci in self.ci]
                lines.append('CI{} = ({:.2f}; {:.2f})'.format(level, low, high))
                lines.append('CI{}_axial = ({:.2f}; {:.2f})'.format(level, low_ax, high_ax))
        if self.rayleigh_p >= 0.001:
            lines.append('Rayleigh p = {:.3f}'.format(self.rayleigh_p))
        else:
            lines.append('Rayleigh p < 0.001')
        return '\n'.join(lines)

    def make_plot(self, colors_auto=1, colors=None, grid_visible=1, r_critical_levels=(0.01, 0.05),
                  tick_positions=(0, 90, 180, 270), tick_labels=('gN', '90', '180', '270'),
                  issave=False, savename=None):
        """Draw the polar dot plot and show it, optionally saving a PNG first.

        Args:
            colors_auto: if true, use matplotlib's colour cycle.
            colors: colour(s) used when ``colors_auto`` is false: one per
                group (cycled if fewer than groups), or one for ungrouped
                data.  A single string counts as one colour; ``None`` or an
                empty sequence falls back to the colour cycle.
            grid_visible: show the polar grid.
            r_critical_levels: significance levels for which a dashed circle
                at the critical resultant length is drawn; ``None`` draws none.
            tick_positions: angular tick positions in degrees.
            tick_labels: labels for ``tick_positions``.
            issave: save the figure as ``'<savename>.png'`` at 200 dpi.
            savename: file-name stem used when ``issave`` is true.
        """
        dpi = 100
        const = self.n_angles * np.pi / 180  # radial size of one stacking step
        sam, lengths, lmax, indexes = prepare_sample(self.data.copy(), 0, self.no_round, self.n_angles)

        ax = plt.subplot(111, polar=True)
        fig = ax.get_figure()
        fig.set_dpi(dpi)
        fig.set_size_inches((5, 5))
        size = fig.get_size_inches()
        ax.set_ylim((0, lmax + const))

        plt.box(on=False)
        ax.set_rgrids((), ())
        ax.grid(grid_visible)

        ax.set_rorigin(0)
        ax.set_theta_zero_location("N")
        ax.set_theta_direction(-1)

        # Maps display (pixel) coordinates back to data coordinates; used to
        # place the summary text.
        inv = ax.transData.inverted()
        # Marker diameter in inches = arc length of n_angles degrees on the
        # unit circle (pixel diameter * pi * n_angles / 360, divided by dpi).
        zero = ax.transData.transform((0, 1))
        pi = ax.transData.transform((np.pi, 1))
        msize = self.n_angles * np.pi * np.sum((zero - pi) * (zero - pi)) ** 0.5 / (dpi * 360)
        marker_area = msize ** 2 * 72 * 72  # scatter sizes are in points**2

        # Unit circle and the mean vector(s); quiver from the origin to
        # (theta=mean, r=mean_l) in data coordinates.
        ax.plot(np.linspace(0, 2 * np.pi, 500), np.ones(500), color='black')
        ax.quiver(0, 0, self.mean, self.mean_l, scale_units='xy', scale=1, angles='xy', zorder=3)
        if self.axial:
            ax.quiver(0, 0, self.mean + np.pi, self.mean_l, scale_units='xy', scale=1, angles='xy', zorder=3)

        if self.grouped:
            groups = self.groups[indexes]  # align labels with the sorted sample
            for i, label in enumerate(np.unique(groups)):
                mask = groups == label
                ax.scatter(sam[mask] * np.pi / 180, lengths[mask], marker_area,
                           edgecolor='black', marker='o', zorder=2, label=label,
                           **self._color_kwargs(colors_auto, colors, i))

            legend_coords = ax.transAxes.inverted().transform([0, size[1] * dpi * 0.8])
            ax.legend(loc=tuple(legend_coords))
        else:
            ax.scatter(sam * np.pi / 180, lengths, marker_area, edgecolor='black',
                       marker='o', zorder=2, **self._color_kwargs(colors_auto, colors, 0))

        if r_critical_levels is not None:
            for level in r_critical_levels:
                critical_r = compute_critical(self.sample_size, level)
                ax.plot(np.linspace(0, 2 * np.pi, 500), np.ones(500) * critical_r,
                        color='black', linestyle='--')

        if self.need_ci:
            # One arc for directional data, two antipodal arcs for axial data.
            arcs = list(self.ci) if self.axial else [self.ci]
            for low, high in arcs:
                ax.plot(np.linspace(low, high, 500), np.ones(500) * (1 - const), color='black')

        ax.set_thetagrids(tick_positions, tick_labels, weight='semibold')
        ax.set_title(self.name)

        ax.text(*inv.transform((size[0] * 0, size[1] * 5)), self._summary_text())

        if issave:
            plt.savefig('{}.png'.format(savename), dpi=200)

        plt.show()


def prepare_sample(data, data_in_angles, no_round, n_angles):
    """Convert angles to degrees, optionally snap them to a grid, and stack close points.

    A point closer than ``n_angles / 2`` degrees to its circular predecessor
    is pushed one radial step further out than that predecessor, so markers
    do not overlap.

    Args:
        data: 1-D array of angles (never modified).
        data_in_angles: true if ``data`` is already in degrees; otherwise it is
            treated as radians and converted.
        no_round: if false, angles are rounded to multiples of
            ``0.7 * n_angles`` degrees so near-identical values stack.
        n_angles: marker size in degrees; sets the stacking threshold and the
            radial step (``n_angles`` degrees expressed in radians).

    Returns:
        ``(data, lengths, lmax, indexes)``:
        - ``data``: the angles in degrees, sorted ascending.
        - ``lengths``: the radial position of each sorted point.
        - ``lmax``: the largest radial position; for empty input, the base radius.
        - ``indexes``: the ``argsort`` indices mapping sorted positions back to
          the input order.

    NOTE(review): the wrap-around check assumes the angles lie within one
    360-degree turn (e.g. [0, 360)); values outside that range may stack
    incorrectly near 0/360.
    """
    step = n_angles * np.pi / 180              # radial offset per stacked marker
    critical_diff = n_angles / 2               # degrees; closer neighbours stack
    base_length = 1 + n_angles * np.pi / 360   # half a step outside the unit circle

    # Work on a float copy: the caller's array is left untouched and integer
    # input converts cleanly.
    data = np.asarray(data, dtype=float)
    if not data_in_angles:
        data = data * (180 / np.pi)

    if not no_round:
        round_const = n_angles * 0.7
        data = np.round(data / round_const) * round_const

    indexes = data.argsort()
    data = data[indexes]
    n = data.shape[0]
    lengths = np.full(n, base_length)
    if n == 0:
        return data, lengths, base_length, indexes

    # gaps[k] is the angular gap from sorted point k to the next one; the last
    # entry is the wrap-around gap back to the first point (+360 degrees).
    gaps = np.diff(data, append=data[0] + 360)
    # Start the chain right after the widest gap: that point sits at the base
    # radius, and every other point is placed only after its predecessor's
    # height is final.  This also keeps a lone point (or a single cluster)
    # from being stacked on itself.
    start = (int(np.argmax(gaps)) + 1) % n
    for offset in range(1, n):
        current = (start + offset) % n
        previous = (current - 1) % n
        if gaps[previous] < critical_diff:
            lengths[current] = lengths[previous] + step

    return data, lengths, lengths.max(), indexes

def bootstrapCI(data, alpha=5, n_iters=10000, rng=None):
    """Percentile-bootstrap confidence interval for the circular mean direction.

    Args:
        data: 1-D array of angles in radians.
        alpha: significance level in percent; the interval spans the central
            ``100 - alpha`` percent of the bootstrap means.
        n_iters: number of bootstrap resamples.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` for
            reproducible results.  Defaults to NumPy's global random state.

    Returns:
        ``np.array([low, high])`` in radians.  The bounds are offsets from the
        bootstrap circular mean, so they may fall outside ``[0, 2*pi)``.  For
        example, ``low`` can be negative for an interval spanning north.

    Raises:
        ValueError: if ``data`` is empty or ``n_iters`` is less than 1.
    """
    data = np.asarray(data, dtype=float)
    n = data.shape[0]
    if n == 0:
        raise ValueError('bootstrapCI needs at least one angle')
    if n_iters < 1:
        raise ValueError('n_iters must be at least 1')
    if rng is None:
        rng = np.random

    # Resample in chunks so the (resamples x n) matrix stays modest in memory
    # while the circular means are still computed in vectorised form.
    chunk = max(1, 10 ** 6 // n)
    means = np.empty(n_iters)
    for start in range(0, n_iters, chunk):
        stop = min(start + chunk, n_iters)
        resamples = rng.choice(data, size=(stop - start, n), replace=True)
        means[start:stop] = circmean(resamples, axis=1)

    bootstrap_mean = circmean(means)
    # Express each bootstrap mean as a signed offset in [-pi, pi] from their
    # circular mean so the percentiles are not broken by the 0/2*pi seam.
    offsets = means - bootstrap_mean
    offsets -= (offsets > np.pi) * 2 * np.pi
    offsets += (offsets < -np.pi) * 2 * np.pi
    ci_low = np.percentile(offsets, alpha / 2) + bootstrap_mean
    ci_high = np.percentile(offsets, 100 - alpha / 2) + bootstrap_mean

    return np.array((ci_low, ci_high))

def rayleigh(r, n):
    """Approximate Rayleigh-test p-value.

    Uses P ~= exp(-Z) with Z = n * r**2, where ``r`` is the mean resultant
    length of ``n`` angles.  Works element-wise on arrays.

    NOTE(review): exp(-Z) is the simplest approximation of the Rayleigh
    p-value; corrected forms (e.g. Zar's) are more accurate for small n.
    ``compute_critical`` inverts this same formula, so change both together.
    """
    z = n * r ** 2
    return np.exp(-z)

def to_positive(data):
    """Shift angles (radians) by at most one full turn into [0, 2*pi].

    Values below 0 gain 2*pi and values above 2*pi lose 2*pi; everything else
    (including exactly 2*pi) is returned unchanged.  Accepts scalars or arrays.
    """
    full_turn = 2 * np.pi
    return data + full_turn * (data < 0) - full_turn * (data > full_turn)

def to_modulo_2pi(sample):
    """Wrap angles (radians) into [0, 2*pi), up to floating-point rounding."""
    full_turn = 2 * np.pi
    return sample % full_turn

def compute_mean(data):
    """Return ``(mean_direction, mean_resultant_length)`` of angles in radians.

    The direction comes from ``scipy.stats.circmean``.  The resultant length
    is the norm of the summed unit vectors divided by the sample size.
    """
    sum_cos = np.cos(data).sum()
    sum_sin = np.sin(data).sum()
    mean_resultant = (sum_sin ** 2 + sum_cos ** 2) ** 0.5 / data.shape[0]
    return circmean(data), mean_resultant

def compute_critical(n, alpha=0.05):
    """Mean resultant length at which ``rayleigh(r, n)`` equals ``alpha``.

    This solves exp(-n * r**2) = alpha for r, i.e. sqrt(-ln(alpha) / n).
    """
    neg_log_alpha = -np.log(alpha)
    return (neg_log_alpha / n) ** 0.5

#@title Load data to notebook
axial = False #@param {type: 'boolean'}
data_is_grouped = False #@param {type: 'boolean'}
data_in_angles = True #@param {type: 'boolean'}
compute_bootstrap_CI = True #@param {type: 'boolean'}
#@markdown (It will actually compute confidence intervals only if Rayleigh p < 0.05)
file_name = 'ex_data.xlsx' #@param {type: 'string'}
plot_title = 'Autumn 2024_ClockShift+6h' #@param {type: 'string'}

# Column 0 holds the angles and, for grouped data, column 1 the group labels.
# Drop only rows missing one of those columns, so other (e.g. sparse note)
# columns in the sheet cannot silently discard observations.
data = pd.read_excel(file_name)
n_used_columns = 2 if data_is_grouped else 1
data = data.dropna(subset=list(data.columns[:n_used_columns]))
sample = data.iloc[:, 0].to_numpy()
groups = data.iloc[:, 1].to_numpy() if data_is_grouped else None

ex_plotter = CircPlotter(sample, name=plot_title, grouped=data_is_grouped,
                         data_in_angles=data_in_angles, groups=groups,
                         n_angles=3, axial=axial)
print(ex_plotter)
# The 95% bootstrap CI is computed only when Rayleigh p < 0.05.
if ex_plotter.rayleigh_p < 0.05 and compute_bootstrap_CI:
    ex_plotter.computeCI(5)

ex_plotter.make_plot(r_critical_levels=[0.05], grid_visible=1, colors_auto=1)

#@title Make final plot and save it!
marker_size = 4 #@param {type: 'integer'}
savename = "LedSE" #@param {type: 'string'}
if savename == "":
  savename = None  # empty name: the figure is neither saved nor downloaded

r_critical_levels = '0.05 0.01' #@param {type: 'string'}
r_critical_levels = [float(i) for i in r_critical_levels.split()]
grid_visible = False #@param {type: 'boolean'}
colors_auto = False #@param {type: 'boolean'}
# Always defined; make_plot only reads it when colors_auto is false.
colors = None
if not colors_auto:
  colors = 'blue' #@param {type: 'string'}
  colors = colors.split()

xticks = '0 90 180 270' #@param {type: 'string'}
xticks = [float(i) for i in xticks.split()]
xlabs = 'mN/gN 90 180 270' #@param {type: 'string'}
xlabs = xlabs.split()
if len(xticks) != len(xlabs):
  raise ValueError('xticks and xlabs need the same number of entries '
                   '(got {} and {})'.format(len(xticks), len(xlabs)))

ex_plotter.n_angles = marker_size
ex_plotter.make_plot(r_critical_levels=r_critical_levels, grid_visible=grid_visible,
                     colors_auto=colors_auto, colors=colors,
                     tick_positions=xticks, tick_labels=xlabs, savename=savename,
                     issave=savename is not None)

download_graph = True #@param {type: 'boolean'}
if download_graph and savename is not None:
  files.download('{}.png'.format(savename))
elif download_graph:
  print('No savename given: the figure was not saved, so there is nothing to download.')

